{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "# 初始化tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')\n",
    "\n",
    "from recurrent.interface import AsyncOutput\n",
    "from recurrent.chat_template.utils import set_chat_template\n",
    "\n",
    "from recurrent.async_generation_manager import AsyncLLMGenerationManager\n",
    "class Dummy(AsyncLLMGenerationManager):\n",
    "    def __init__(self, tokenizer):\n",
    "        self.tokenizer = tokenizer\n",
    "        set_chat_template(tokenizer)\n",
    "\n",
    "manager = Dummy(\n",
    "    tokenizer,\n",
    ")\n",
    "messages=[\n",
    "    {\"role\": \"system\", \"content\": \"System.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"User\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"Assistant\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"tool\",\n",
    "        \"content\": \"66666\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"777\",\n",
    "    }\n",
    "]\n",
    "messages2=[\n",
    "    {\"role\": \"system\", \"content\": \"System.\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"UserUser\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"AssistantAssistant\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"tool\",\n",
    "        \"content\": \"66666\",\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": \"777777\",\n",
    "        \"finished\": False\n",
    "    }\n",
    "]\n",
    "import torch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'prompts': array([array([151644,   8948,    198,   2320,     13, 151645,    271, 151644,\n",
       "                   872,    198,   1474,   1474, 151645,    198, 151644,  77091,\n",
       "                   198])                                                       ],\n",
       "        dtype=object),\n",
       "  'responses': array([array([151644,   8948,    198,   2610,    525,   1207,  16948,     11,\n",
       "                  3465,    553,  54364,  14817,     13,   1446,    525,    264,\n",
       "                 10950,  17847,     13, 151645,    198, 151644,    872,    198,\n",
       "                  9571, 151645,    198, 151644,  77091,    198,  71703,  71703,\n",
       "                151645,    198, 151644,    872,    198,     27,  14172,   9655,\n",
       "                   397,     21,     21,     21,     21,     21,    198,    522,\n",
       "                 14172,   9655,     29, 151645,    198, 151644,  77091,    198,\n",
       "                    22,     22,     22,     22,     22,     22])               ],\n",
       "        dtype=object),\n",
       "  'loss_mask': array([array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])           ],\n",
       "        dtype=object),\n",
       "  'sample_index': array([0]),\n",
       "  'final_mask': array([ True])},\n",
       " {'prompts': array([array([151644,   8948,    198,   2320,     13, 151645,    271, 151644,\n",
       "                   872,    198,   1474, 151645,    198, 151644,  77091,    198]),\n",
       "         array([151644,   8948,    198,   2320,     13, 151645,    271, 151644,\n",
       "                   872,    198,   1474,   1474, 151645,    198, 151644,  77091,\n",
       "                   198])                                                       ],\n",
       "        dtype=object),\n",
       "  'responses': array([array([151644,   8948,    198,   2610,    525,   1207,  16948,     11,\n",
       "                  3465,    553,  54364,  14817,     13,   1446,    525,    264,\n",
       "                 10950,  17847,     13, 151645,    198, 151644,    872,    198,\n",
       "                  9571, 151645,    198, 151644,  77091,    198,  71703, 151645,\n",
       "                   198, 151644,    872,    198,     27,  14172,   9655,    397,\n",
       "                    21,     21,     21,     21,     21,    198,    522,  14172,\n",
       "                  9655,     29, 151645,    198, 151644,  77091,    198,     22,\n",
       "                    22,     22, 151645])                                       ,\n",
       "         array([151644,   8948,    198,   2610,    525,   1207,  16948,     11,\n",
       "                  3465,    553,  54364,  14817,     13,   1446,    525,    264,\n",
       "                 10950,  17847,     13, 151645,    198, 151644,    872,    198,\n",
       "                  9571, 151645,    198, 151644,  77091,    198,  71703,  71703,\n",
       "                151645,    198, 151644,    872,    198,     27,  14172,   9655,\n",
       "                   397,     21,     21,     21,     21,     21,    198,    522,\n",
       "                 14172,   9655,     29, 151645,    198, 151644,  77091,    198,\n",
       "                    22,     22,     22,     22,     22,     22])               ],\n",
       "        dtype=object),\n",
       "  'loss_mask': array([array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1])                    ,\n",
       "         array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])           ],\n",
       "        dtype=object),\n",
       "  'sample_index': array([1, 1]),\n",
       "  'final_mask': array([False,  True])}]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs = [AsyncOutput(\n",
    "        conversations=[messages2],\n",
    "        sample_index=torch.tensor([0]),\n",
    "        final_mask=torch.tensor([True]),\n",
    "        timing_raw={}\n",
    "    ), AsyncOutput(\n",
    "        conversations=[messages, messages2],\n",
    "        sample_index=torch.tensor([1, 1]),\n",
    "        final_mask=torch.tensor([False, True]),\n",
    "        timing_raw={}\n",
    "    )]\n",
    "outs = [manager.tokenize_output(o) for o in outputs]\n",
    "def check(out):\n",
    "    assert len(out['loss_mask'][0]) == len(out['responses'][0])\n",
    "for out in outs:\n",
    "    check(out)\n",
    "outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataProto(batch=TensorDict(\n",
       "    fields={\n",
       "        attention_mask: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.bool, is_shared=False),\n",
       "        final_mask: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.bool, is_shared=False),\n",
       "        input_ids: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        loss_mask: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        position_ids: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        prompts: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        responses: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        sample_index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
       "    batch_size=torch.Size([3]),\n",
       "    device=None,\n",
       "    is_shared=False), non_tensor_batch={}, meta_info={})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch = manager.concat_output(\n",
    "    outs\n",
    ")\n",
    "batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorDict(\n",
       "    fields={\n",
       "        attention_mask: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.bool, is_shared=False),\n",
       "        final_mask: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.bool, is_shared=False),\n",
       "        input_ids: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        loss_mask: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        position_ids: Tensor(shape=torch.Size([3, 49]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        prompts: Tensor(shape=torch.Size([3, 17]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        responses: Tensor(shape=torch.Size([3, 32]), device=cpu, dtype=torch.int64, is_shared=False),\n",
       "        sample_index: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False)},\n",
       "    batch_size=torch.Size([3]),\n",
       "    device=None,\n",
       "    is_shared=False)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 17])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.batch['prompts'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[0, 0],\n",
       "         [0, 1],\n",
       "         [0, 0]]),\n",
       " tensor([[False,  True],\n",
       "         [ True,  True],\n",
       "         [False,  True]]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from recurrent.utils import pad_tensor_list_to_length\n",
    "length = [1,2,1]\n",
    "pad_tensor_list_to_length(\n",
    "    [torch.arange(l) for l in length],\n",
    "    pad_token_id=0,\n",
    "    return_mask=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prompts\n",
      "'<|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUserUser<|im_end|>\\n<|im_start|>assistant\\n'\n",
      "'<|endoftext|><|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUser<|im_end|>\\n<|im_start|>assistant\\n'\n",
      "'<|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUserUser<|im_end|>\\n<|im_start|>assistant\\n'\n",
      "responses\n",
      "'AssistantAssistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777777'\n",
      "'Assistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777<|im_end|><|endoftext|><|endoftext|><|endoftext|>'\n",
      "'AssistantAssistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777777'\n",
      "input_ids\n",
      "'<|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUserUser<|im_end|>\\n<|im_start|>assistant\\nAssistantAssistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777777'\n",
      "'<|endoftext|><|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUser<|im_end|>\\n<|im_start|>assistant\\nAssistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777<|im_end|><|endoftext|><|endoftext|><|endoftext|>'\n",
      "'<|im_start|>system\\nSystem.<|im_end|>\\n\\n<|im_start|>user\\nUserUser<|im_end|>\\n<|im_start|>assistant\\nAssistantAssistant<|im_end|>\\n<|im_start|>user\\n<tool_response>\\n66666\\n</tool_response><|im_end|>\\n<|im_start|>assistant\\n777777'\n",
      "attention_mask\n",
      "tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True],\n",
      "        [False,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True, False, False, False],\n",
      "        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
      "          True,  True,  True,  True,  True,  True,  True,  True,  True]])\n",
      "position_ids\n",
      "tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
      "         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
      "         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48],\n",
      "        [ 0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,\n",
      "         17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,\n",
      "         35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 44, 44, 44],\n",
      "        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
      "         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,\n",
      "         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48]])\n",
      "loss_mask\n",
      "tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 1, 1, 1, 1, 0, 0, 0],\n",
      "        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "         0, 0, 1, 1, 1, 1, 1, 1]])\n",
      "sample_index\n",
      "tensor([0, 1, 1])\n",
      "final_mask\n",
      "tensor([ True, False,  True])\n"
     ]
    }
   ],
   "source": [
    "for k in batch.batch.keys():\n",
    "    print(k)\n",
    "    if k  in ['prompts', 'responses', 'input_ids']:\n",
    "        for l in tokenizer.batch_decode(batch.batch[k]):\n",
    "            print(repr(l))\n",
    "    else:\n",
    "        print(batch.batch[k])\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "fileId": "1f55fe1c-32ca-4acc-9b90-7c2ee67c22ae",
  "filePath": "/opt/tiger/verl/recurrent/test2.ipynb",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
